#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pwn import *

from aes_utils import mix_columns, mix_columns_inv, shift_rows, shift_rows_inv

# Path to the locally built challenge binary (spawned when LOCAL is set).
exe = '../chall/target/debug/shuffled-aes'

# Remote target: these two globals are only defined when BOTH HOST and PORT
# were passed on the command line.
if args["HOST"] and args["PORT"]:
    host, port = args["HOST"], args["PORT"]

def start(argv=None, *a, **kw):
    '''Start the exploit against the target.

    With the ``LOCAL`` argument set, spawn ``exe`` (followed by any extra
    ``argv`` entries) via ``process``, forwarding ``*a``/``**kw`` to it.
    Otherwise connect to the ``HOST``/``PORT`` given on the command line;
    if those were not both supplied, abort via ``log.error`` with a clear
    message instead of failing on an undefined name.
    '''
    if args.LOCAL:
        # ``argv=None`` avoids a shared mutable default list.
        return process([exe, *(argv or [])], *a, **kw)
    # The module-level ``host``/``port`` globals only exist under this exact
    # condition, so check it before touching them.
    if not (args["HOST"] and args["PORT"]):
        log.error('run with LOCAL, or provide both HOST=... and PORT=...')
    return connect(host, port)

# -- Exploit goes here --
io = start()

# The service sends the encrypted flag after this prefix, as two hex
# fields: "<nonce hex> <ciphertext hex>".
io.recvuntil(b'Here\'s your flag: ')
nonce, ct = (bytes.fromhex(field) for field in io.recvline().strip().decode().split())

# Number of encryption-oracle queries issued so far (bumped by get_keystream).
queries = 0

def get_keystream(num_blocks = (len(ct) + 15) // 16):
    '''Query the encryption oracle once and recover its raw keystream.

    Sends ``num_blocks`` 16-byte blocks of known plaintext (ASCII ``'0'``
    bytes) at the ``pt> `` prompt. It then parses the ``<nonce hex> <ct hex>``
    reply that follows ``ct: `` and XORs the ciphertext with the plaintext.

    Returns ``(nonce, keystream)`` as bytes. ``zip(..., strict=True)`` raises
    ``ValueError`` if the returned ciphertext length differs from the
    plaintext length.
    '''
    global queries
    queries += 1

    io.recvuntil(b'pt> ')
    known_pt = b'0' * 16 * num_blocks
    io.sendline(known_pt)

    io.recvuntil(b'ct: ')
    nonce_hex, ct_hex = io.recvline().strip().decode().split()
    query_nonce = bytes.fromhex(nonce_hex)
    query_ct = bytes.fromhex(ct_hex)

    stream = bytes(a ^ b for a, b in zip(known_pt, query_ct, strict=True))
    return query_nonce, stream

def get_pairs(num_blocks = (len(ct) + 15) // 16):
    '''Collect (counter input, de-linearised keystream) pairs from one query.

    Obtains a fresh ``(nonce, keystream)`` via ``get_keystream``. For block
    ``i``, the counter input is that query's nonce followed by ``i`` as a
    4-byte big-endian integer. It is paired with keystream block ``i`` after
    undoing ten rounds of ``shift_rows`` + ``mix_columns`` (applying
    ``mix_columns_inv`` then ``shift_rows_inv`` ten times). The main loop
    treats the result as a per-byte-position substitution of the counter
    input.

    Returns a list of ``num_blocks`` ``(bytes, bytes)`` tuples.
    '''
    query_nonce, keystream = get_keystream(num_blocks)

    def strip_linear_layers(chunk):
        # Inverse of (shift_rows; mix_columns) x10; both helpers work in place.
        state = bytearray(chunk)
        for _ in range(10):
            mix_columns_inv(state)
            shift_rows_inv(state)
        return bytes(state)

    return [
        (
            query_nonce + blk.to_bytes(4, 'big'),
            strip_linear_layers(keystream[16 * blk : 16 * (blk + 1)]),
        )
        for blk in range(num_blocks)
    ]

# One table per byte position (0..15): input byte -> observed output byte.
sbox_dicts = [{} for _ in range(16)]

def try_decrypt():
    '''Attempt to decrypt the flag ciphertext with the tables learned so far.

    For each 16-byte block of ``ct``, build its counter input: the flag
    ``nonce`` followed by the block index as a 4-byte big-endian integer.
    Map each byte through the matching per-position table in ``sbox_dicts``,
    then apply ten rounds of ``shift_rows``/``mix_columns`` to rebuild that
    block's keystream, and XOR it with the ciphertext.

    Returns the decrypted flag as bytes, or ``None`` if a byte needed for
    some block has no table entry yet.
    '''
    blocks = []
    for idx in range(0, len(ct), 16):
        pt = nonce + (idx // 16).to_bytes(4, 'big')
        try:
            ks = bytearray(sbox_dicts[i][pt_byte] for i, pt_byte in enumerate(pt))
        except KeyError:
            # This input byte has not been observed at this position yet;
            # more oracle queries are needed.
            return None

        for _ in range(10):
            shift_rows(ks)
            mix_columns(ks)

        # Slice just this block (not the whole tail); zip() stops at the
        # shorter input, so a final partial block is handled too.
        blocks.append(bytes(c ^ k for c, k in zip(ct[idx:idx + 16], ks)))
    return b''.join(blocks)

# Keep querying the oracle, growing the per-position tables, until every
# byte needed to rebuild the flag's keystream has been observed.
flag = None
while flag is None:
    for inp, out in get_pairs():
        for idx in range(16):
            in_byte, out_byte = inp[idx], out[idx]
            known = sbox_dicts[idx].setdefault(in_byte, out_byte)
            # Explicit check rather than ``assert`` so it still runs under
            # ``python -O``. A mismatch would contradict the assumption that
            # each byte position goes through one fixed byte-to-byte table.
            if known != out_byte:
                log.error(
                    'inconsistent table at position %d: %#04x -> %#04x, '
                    'previously %#04x',
                    idx, in_byte, out_byte, known,
                )

    flag = try_decrypt()

log.info('queried %d queries of %d blocks each', queries, (len(ct) + 15) // 16)
log.success('flag: %s', flag.decode())
